{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 注意! 本脚本制作的字典与本目录下的bert_word2idx_extend.json文件所保存的字典并不相同,  因为bert_word2idx_extend.json文件根据我自己的个人需求加了很多词, 本脚本只做演示目的, 其实无须重新制作字典."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 现在已经有准备好的BERT维基百科训练语料, 已经分割为train_wiki.txt和test_wiki.txt\n",
    "# 语料来源: https://github.com/brightmart/nlp_chinese_corpus\n",
    "# 在准备好的文件里, 写成了下面的格式, 每一行是一条string, 可以eval为python dict\n",
    "# 分别对应着两句有着上下文关系的句子,\n",
    "# 示例:\n",
    "# \"{'text1': '眼蛱蝶族（学名：Junoniini）是蛱蝶科蛱蝶亚科中的一个族。', \n",
    "#   'text2': '此分类的物种在始新世末至渐新世初开始形成。'}\"\n",
    "# 在这个项目里, BERT的训练中, 由./BERT/dataset/wiki_dataset.py文件中的脚本读取txt文件,\n",
    "# 并动态随机做Masked LM和next sentence的mini batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在这里主要是演示怎样制作用来训练的字典, 用来做tokenize, 也就是把汉字转换为token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注意! 以下操作可能会很慢, 因为语料比较大"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8667666\n"
     ]
    }
   ],
   "source": [
    "# 加载所有的语料\n",
    "# 注意这里可能会很慢, 可能需要等到5分钟\n",
    "with open(\"train_wiki.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    all_wiki_corpus = [i for i in f.readlines()]\n",
    "with open(\"test_wiki.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    all_wiki_corpus += [i for i in f.readlines()]\n",
    "print(len(all_wiki_corpus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9045453\n"
     ]
    }
   ],
   "source": [
    "# 因为这里上下句有重复的, 所以需要去重, 之后制作字典\n",
    "# 注意这里可能会很慢, 可能需要等到5分钟\n",
    "all_text = []\n",
    "for dic in all_wiki_corpus:\n",
    "    dic = eval(dic)\n",
    "    all_text += [v for _, v in dic.items()]\n",
    "all_text = list(set(all_text))\n",
    "print(len(all_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'在音乐方面，它更常指作品的类型和风格的更替。'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_text[333]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 我们要制作字典, 首先要制作一个记录所有字出现频率的dict, 然后可以舍去出现频率非常低的字, 也可以不舍去\n",
    "def get_word2tf(corpus_list):\n",
    "    # word2tf是记录字频的dict\n",
    "    word2tf = {}\n",
    "    for text in corpus_list:\n",
    "        for char in list(text):\n",
    "            char = char.lower()\n",
    "            word2tf = update_dic(char, word2tf)\n",
    "    return word2tf\n",
    "\n",
    "def update_dic(char, word2tf):\n",
    "    if word2tf.get(char) is None:\n",
    "        word2tf[char] = 1\n",
    "    else:\n",
    "        word2tf[char] += 1\n",
    "    return word2tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注意这里可能会很慢, 可能需要等到5-10分钟\n",
    "word2tf = get_word2tf(all_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19211\n"
     ]
    }
   ],
   "source": [
    "print(len(word2tf))\n",
    "# 这里可以根据需要舍去字频较低的字, 我们这里不舍去任何东西, 因为只有19211个字..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 我们要训练BERT, 所以我们会有一些特殊的token, 例如#CLS#, #PAD#(用来补足长度)等等,\n",
    "# 所以我们留出前20个token做备用, 实际字的token从序号20开始"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# word2idx是我们将要制作的字典\n",
    "word2idx = {}\n",
    "# 定义一些特殊token\n",
    "pad_index = 0 # 用来补长度和空白\n",
    "unk_index = 1 # 用来表达未知的字, 如果字典里查不到\n",
    "cls_index = 2 #CLS#\n",
    "sep_index = 3 #SEP#\n",
    "mask_index = 4 # 用来做Masked LM所做的遮罩\n",
    "num_index = 5 # (可选) 用来替换语句里的所有数字, 例如把 \"23.9\" 直接替换成 #num#\n",
    "word2idx[\"#PAD#\"] = pad_index\n",
    "word2idx[\"#UNK#\"] = unk_index\n",
    "word2idx[\"#SEP#\"] = sep_index\n",
    "word2idx[\"#CLS#\"] = cls_index\n",
    "word2idx[\"#MASK#\"] = mask_index\n",
    "word2idx[\"#NUM#\"] = num_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19217\n"
     ]
    }
   ],
   "source": [
    "idx = 20\n",
    "for char, v in word2tf.items():\n",
    "    word2idx[char] = idx\n",
    "    idx += 1\n",
    "print(len(word2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注意!! 我们在训练BERT的时候, 实际需要初始化的字向量矩阵的维度是 [19211+20, embedding_dim]\n",
    "# 不要忘记我们预留的20个特殊token的空间"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 写入json\n",
    "with open('bert_word2idx.json', 'w+', encoding='utf-8') as f:\n",
    "    f.write(json.dumps(word2idx, ensure_ascii=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 至此字典制作完毕"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
